{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SVI Part IV: Tips and Tricks\n",
    "\n",
    "The three SVI tutorials leading up to this one ([Part I](http://pyro.ai/examples/svi_part_i.html), [Part II](http://pyro.ai/examples/svi_part_ii.html), & [Part III](http://pyro.ai/examples/svi_part_iii.html)) go through\n",
    "the various steps involved in using Pyro to do variational\n",
    "inference.\n",
    "Along the way we defined models and guides (i.e. variational distributions),\n",
    "set up variational objectives (in particular [ELBOs](https://docs.pyro.ai/en/dev/inference_algos.html?highlight=elbo#module-pyro.infer.elbo)),\n",
    "and constructed optimizers ([pyro.optim](http://docs.pyro.ai/en/dev/optimization.html)). \n",
    "The effect of all this machinery is to cast Bayesian inference as a *stochastic optimization problem*. \n",
    "\n",
    "This is all very useful, but in order to arrive at our ultimate goal—learning model parameters, inferring approximate posteriors, making predictions with the posterior predictive distribution, etc.—we need to successfully solve this optimization problem. \n",
    "Depending on the details of the particular problem—for example the dimensionality of the latent space, whether we have discrete latent variables, and so on—this can be easy or hard. \n",
    "In this tutorial we cover a few tips and tricks we expect to be generally useful for users doing variational inference in Pyro. *ELBO not converging!? Running into NaNs!?* Look below for possible solutions!  \n",
    "\n",
    "#### Pyro Forum\n",
    "\n",
    "If you’re still having trouble with optimization after reading this tutorial, please don’t hesitate to ask a question on our [forum](https://forum.pyro.ai/)!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Start with a small learning rate\n",
    "\n",
    "While large learning rates might be appropriate for some problems, it's usually good practice to start with small learning rates like $10^{-3}$\n",
    "or $10^{-4}$:\n",
    "```python\n",
    "optimizer = pyro.optim.Adam({\"lr\": 0.001})\n",
    "```\n",
    "This is because ELBO gradients are *stochastic*, and potentially high-variance, so large learning rates can quickly lead to regions of model/guide parameter space that are numerically unstable or otherwise undesirable.\n",
    "\n",
    "You can try a larger learning rate once you have achieved stable\n",
    "ELBO optimization using a smaller learning rate. \n",
    "This is often worthwhile because an excessively small learning rate can also lead to poor optimization: progress is slow, and the optimizer can get stuck in poor local optima of the ELBO."
   ]
  },
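  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of this tip, the toy problem below infers the mean of some synthetic Gaussian data; the model, guide, data, and step count are hypothetical choices for illustration. It uses `Adam` with a learning rate of $10^{-3}$ and records the loss at every step, so you can check that optimization is stable before trying a larger learning rate.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = 3.0 + torch.randn(100)  # hypothetical observations centered near 3.0\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample(\"loc\", dist.Normal(0.0, 10.0))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        pyro.sample(\"obs\", dist.Normal(loc, 1.0), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    q_loc = pyro.param(\"q_loc\", torch.tensor(0.0))\n",
    "    q_scale = pyro.param(\"q_scale\", torch.tensor(0.1),\n",
    "                         constraint=constraints.positive)\n",
    "    pyro.sample(\"loc\", dist.Normal(q_loc, q_scale))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "optimizer = Adam({\"lr\": 0.001})  # start with a small learning rate\n",
    "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
    "losses = [svi.step(data) for _ in range(200)]\n",
    "\n",
    "# A numerically stable run should produce only finite losses\n",
    "all_finite = all(math.isfinite(loss) for loss in losses)\n",
    "```\n",
    "\n",
    "With a learning rate this small, `q_loc` moves only slowly toward the data mean; once the loss trace looks stable, you could rerun with a larger value such as `{\"lr\": 0.01}` and compare."
   ]
  },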
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Use Adam or ClippedAdam by default\n",
    "\n",
    "Use [Adam](http://docs.pyro.ai/en/stable/optimization.html?highlight=clippedadam#pyro.optim.pytorch_optimizers.Adam)\n",
    "or [ClippedAdam](http://docs.pyro.ai/en/stable/optimization.html?highlight=clippedadam#pyro.optim.optim.ClippedAdam) by default when doing Stochastic Variational Inference. Note that `ClippedAdam` is just a convenient extension of `Adam` that provides built-in support for learning rate decay and gradient clipping.\n",
    "\n",
    "The basic reason these optimization algorithms tend to do well in the context of variational inference is that the smoothing they provide via per-parameter momentum is often essential when the optimization problem is very stochastic. Note that in SVI, stochasticity can come from sampling latent variables, from subsampling data, or from both.\n",
    "\n",
    "In addition to tuning the learning rate, in some cases it may be necessary to tune the pair of `betas` hyperparameters that control the momentum used by `Adam`. In particular, for very stochastic models it may make sense to use a higher value of $\\beta_1$:\n",
    "\n",
    "```python\n",
    "betas = (0.95, 0.999)\n",
    "```\n",
    "instead of \n",
    "```python\n",
    "betas = (0.90, 0.999)\n",
    "```"
   ]
  },
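  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the settings discussed in this tip (the numeric values are illustrative assumptions, not recommendations), both optimizers take a dictionary of arguments: `betas` for `Adam`, plus `clip_norm` for gradient clipping and `lrd` for per-step learning rate decay in `ClippedAdam`.\n",
    "\n",
    "```python\n",
    "from pyro.optim import Adam, ClippedAdam\n",
    "\n",
    "# Adam with a higher beta_1 (more momentum smoothing) for a very stochastic problem\n",
    "adam = Adam({\"lr\": 0.001, \"betas\": (0.95, 0.999)})\n",
    "\n",
    "# ClippedAdam takes the same lr and betas, plus gradient clipping and learning rate decay\n",
    "clipped_adam = ClippedAdam({\n",
    "    \"lr\": 0.001,\n",
    "    \"betas\": (0.95, 0.999),\n",
    "    \"clip_norm\": 10.0,\n",
    "    \"lrd\": 0.999,\n",
    "})\n",
    "```\n",
    "\n",
    "Either object can then be passed as the optimizer argument to `pyro.infer.SVI`."
   ]
  },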
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Consider using a decaying learning rate\n",
    "\n",
    "While a moderately large learning rate can be useful at the beginning of optimization when you're far from the optimum and want to take large gradient steps, it's often useful to have a smaller learning rate later on so that you don't bounce around the optimum excessively without converging. \n",
    "One way to do this is to use the learning rate schedulers [provided](http://docs.pyro.ai/en/stable/optimization.html?highlight=scheduler#pyro.optim.lr_scheduler.PyroLRScheduler) by Pyro. For example usage see the code snippet [here](https://github.com/pyro-ppl/pyro/blob/a106882e8ffbfe6ac96f19aef9a218026482ed51/examples/scanvi/scanvi.py#L265). \n",
    "Another convenient way to do this is to use the [ClippedAdam](http://docs.pyro.ai/en/stable/optimization.html?highlight=clippedadam#pyro.optim.optim.ClippedAdam) optimizer that has built-in support for learning rate decay via the `lrd` argument:\n",
    "\n",
    "```python\n",
    "num_steps = 1000\n",
    "initial_lr = 0.001\n",
    "gamma = 0.1  # final learning rate will be gamma * initial_lr\n",
    "lrd = gamma ** (1 / num_steps)\n",
    "optim = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd})\n",
    "```"
   ]
  },
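  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind the `lrd` recipe is as follows. `ClippedAdam` multiplies its learning rate by `lrd` once per step, so after `num_steps` steps the rate is `initial_lr * lrd ** num_steps`, which equals `gamma * initial_lr`. The sketch below checks this and, as an alternative, sets up one of Pyro's scheduler wrappers, `pyro.optim.ExponentialLR` (a wrapper around `torch.optim.lr_scheduler.ExponentialLR`, whose `gamma` is a per-step decay factor). The toy model, guide, and decay values are hypothetical.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import ExponentialLR\n",
    "\n",
    "# ClippedAdam multiplies lr by lrd once per step, so after num_steps steps\n",
    "# the rate is initial_lr * lrd ** num_steps = gamma * initial_lr\n",
    "num_steps = 1000\n",
    "initial_lr = 0.001\n",
    "gamma = 0.1\n",
    "lrd = gamma ** (1 / num_steps)\n",
    "final_lr = initial_lr * lrd ** num_steps  # approximately 1e-4\n",
    "\n",
    "# Alternative: a wrapped PyTorch scheduler whose gamma is the per-step decay\n",
    "# factor, chosen here so the rate drops 10x over decay_steps steps\n",
    "decay_steps = 100\n",
    "scheduler = ExponentialLR({\n",
    "    \"optimizer\": torch.optim.Adam,\n",
    "    \"optim_args\": {\"lr\": 0.01},\n",
    "    \"gamma\": 0.1 ** (1 / decay_steps),\n",
    "})\n",
    "\n",
    "def model():\n",
    "    # hypothetical toy model with a single latent variable\n",
    "    pyro.sample(\"z\", dist.Normal(0.0, 1.0))\n",
    "\n",
    "def guide():\n",
    "    loc = pyro.param(\"loc\", torch.tensor(0.5))\n",
    "    pyro.sample(\"z\", dist.Normal(loc, 1.0))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "svi = SVI(model, guide, scheduler, loss=Trace_ELBO())\n",
    "for _ in range(decay_steps):\n",
    "    svi.step()\n",
    "    scheduler.step()  # advance the learning rate schedule once per SVI step\n",
    "```\n",
    "\n",
    "With a scheduler you call `scheduler.step()` yourself (here once per SVI step; once per epoch is another common choice), since `svi.step()` only takes an optimizer step."
   ]
  },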
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Make sure your model and guide distributions have the same support\n",
    "\n",
    "Suppose you have a distribution in your `model` with constrained support, e.g. a LogNormal distribution, which has support on the positive real axis:\n",
    "```python\n",
    "def model():\n",
    "    pyro.sample(\"x\", dist.LogNormal(0.0, 1.0))\n",
    "``` \n",
    "Then you need to ensure that the accompanying `sample` site in the `guide` has the same support:\n",
    "```python\n",
    "def good_guide():\n",
    "    loc = pyro.param(\"loc\", torch.tensor(0.0))\n",
    "    pyro.sample(\"x\", dist.LogNormal(loc, 1.0))\n",
    "``` \n",
    "If you fail to do this and use for example the following inadmissible guide:\n",
    "```python\n",
    "def bad_guide():\n",
    "    loc = pyro.param(\"loc\", torch.tensor(0.0))\n",
    "    # Normal may sample x < 0\n",
    "    pyro.sample(\"x\", dist.Normal(loc, 1.0))  \n",
    "```\n",
    "you will likely run into NaNs very quickly. \n",
    "This is because the `log_prob` of a LogNormal distribution evaluated at a sample `x` that satisfies `x<0` is undefined, and the `bad_guide` is likely to produce such samples.\n"
   ]
  },
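  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One quick check for this kind of mismatch is to compare the `support` of the model distribution with that of the corresponding guide distribution, and to see what happens when a sample outside the model's support is scored. The sketch below passes `validate_args=True` explicitly, an assumption made so the example does not depend on global validation settings.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro.distributions as dist\n",
    "\n",
    "model_dist = dist.LogNormal(0.0, 1.0, validate_args=True)\n",
    "good_guide_dist = dist.LogNormal(0.0, 1.0)\n",
    "bad_guide_dist = dist.Normal(0.0, 1.0)\n",
    "\n",
    "# Both LogNormals share the same support object (the positive reals)\n",
    "good_match = model_dist.support is good_guide_dist.support\n",
    "# The Normal's support is the whole real line, so it can propose x < 0\n",
    "bad_match = model_dist.support is bad_guide_dist.support\n",
    "\n",
    "# Scoring x < 0 under the model is invalid; with validation enabled this raises\n",
    "try:\n",
    "    model_dist.log_prob(torch.tensor(-1.0))\n",
    "    raised = False\n",
    "except ValueError:\n",
    "    raised = True\n",
    "```"
   ]
  },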
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Constrain parameters that need to be constrained\n",
    "In a similar vein, you need to make sure that the parameters used to instantiate distributions are valid; otherwise you will quickly run into NaNs. \n",
    "For example the `scale` parameter of a Normal distribution needs to be positive. Thus the following `bad_guide` is problematic:\n",
    "```python\n",
    "def bad_guide():\n",
    "    # unconstrained: an optimizer step can push scale to zero or below\n",
    "    scale = pyro.param(\"scale\", torch.tensor(1.0))\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, scale))\n",
    "``` \n",
    "while the following `good_guide` correctly uses a constraint to ensure positivity:\n",
    "```python\n",
    "from pyro.distributions import constraints\n",
    "\n",
    "def good_guide():\n",
    "    scale = pyro.param(\"scale\", torch.tensor(0.05),\n",
    "                       constraint=constraints.positive)\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, scale))\n",
    "``` "
   ]
  },
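  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, Pyro stores a constrained `pyro.param` as an unconstrained tensor and maps it back through a transform (for `constraints.positive`, an exponential) whenever it is read, so gradient steps cannot push it out of bounds. The sketch below stresses this with a hypothetical toy model whose optimal guide scale is tiny and a deliberately large learning rate; both are assumptions chosen for illustration, not recommended settings.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "def model():\n",
    "    # hypothetical toy model whose optimal guide scale (0.001) is close to zero\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, 0.001))\n",
    "\n",
    "def good_guide():\n",
    "    scale = pyro.param(\"scale\", torch.tensor(0.05),\n",
    "                       constraint=constraints.positive)\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, scale))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "# A deliberately aggressive learning rate, to stress the constraint\n",
    "svi = SVI(model, good_guide, Adam({\"lr\": 0.1}), loss=Trace_ELBO())\n",
    "for _ in range(500):\n",
    "    svi.step()\n",
    "\n",
    "# Optimization acts on an unconstrained value mapped back through exp,\n",
    "# so the scale can shrink toward 0.001 but can never become negative\n",
    "final_scale = pyro.param(\"scale\").item()\n",
    "scale_is_valid = math.isfinite(final_scale) and final_scale > 0\n",
    "```"
   ]
  },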
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6. If you are having trouble constructing a custom guide, use an AutoGuide\n",
    "\n",
    "In order for a model/guide pair to lead to stable optimization a number of conditions need to be satisfied, some of which we have covered above. \n",
    "Sometimes it can be difficult to diagnose the reason for numerical instability or poor convergence. \n",
    "Among other reasons this is because the fundamental issue could arise in a number of different places: in the model, in the guide, or in the choice of optimization algorithm or hyperparameters. \n",
    "\n",
    "Sometimes the problem is actually in your model even though you think it's in the guide. \n",
    "Conversely, sometimes the problem is in your guide even though you think it's in the model or somewhere else. \n",
    "For these reasons it can be helpful to reduce the number of moving parts while you try to identify the underlying issue.\n",
    "One convenient way to do this is to replace your custom guide with a [pyro.infer.AutoGuide](http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide). \n",
    "\n",
    "For example, if all the latent variables in your model are continuous, you can try a [pyro.infer.AutoNormal](http://docs.pyro.ai/en/stable/infer.autoguide.html#autonormal) guide.\n",
    "Alternatively, you can use MAP inference instead of full-blown variational inference. See the [MLE/MAP](http://pyro.ai/examples/mle_map.html) tutorial for further details. Once you have MAP inference working, there's good reason to believe that your model is set up correctly (at least as far as basic numerical stability is concerned). \n",
    "If you're interested in obtaining approximate posterior distributions, you can now follow up with full-blown SVI. Indeed a natural order of operations might use the following sequence of increasingly flexible autoguides:\n",
    "\n",
    "[AutoDelta](http://docs.pyro.ai/en/stable/infer.autoguide.html#autodelta)   →  [AutoNormal](http://docs.pyro.ai/en/stable/infer.autoguide.html#autonormal)  →  [AutoLowRankMultivariateNormal](http://docs.pyro.ai/en/stable/infer.autoguide.html#autolowrankmultivariatenormal)\n",
    "\n",
    "If you find that you want a more flexible guide or that you want to take more control over how exactly the guide is defined, at this juncture you can proceed to build a custom guide. \n",
    "One way to go about doing this is to leverage [easy guides](http://pyro.ai/examples/easyguide.html), which strike a balance between the control of a fully custom guide and the automation of an autoguide.\n",
    "\n",
    "Also note that autoguides offer several initialization strategies and it may be necessary in some cases to experiment with these in order to get good optimization performance. \n",
    "One way to control initialization behavior is via the `init_loc_fn` argument.\n",
    "For example usage of `init_loc_fn`, including example usage for the easy guide API, see [here](https://github.com/pyro-ppl/pyro/blob/a106882e8ffbfe6ac96f19aef9a218026482ed51/examples/sparse_gamma_def.py#L202)."
   ]
  },
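  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch of the progression suggested above, on a hypothetical toy model (the model, data, learning rate, step count, and `rank=1` are assumptions for illustration): each autoguide is fit in turn, and `init_loc_fn=init_to_median` selects one of Pyro's built-in initialization strategies.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.infer.autoguide import AutoDelta, AutoLowRankMultivariateNormal, AutoNormal\n",
    "from pyro.infer.autoguide.initialization import init_to_median\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = 2.0 + 0.5 * torch.randn(50)  # hypothetical observations\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample(\"loc\", dist.Normal(0.0, 5.0))\n",
    "    scale = pyro.sample(\"scale\", dist.LogNormal(0.0, 1.0))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        pyro.sample(\"obs\", dist.Normal(loc, scale), obs=data)\n",
    "\n",
    "def fit(guide, num_steps=300):\n",
    "    # Fresh parameters for each guide, then a short SVI run\n",
    "    pyro.clear_param_store()\n",
    "    svi = SVI(model, guide, Adam({\"lr\": 0.01}), loss=Trace_ELBO())\n",
    "    return [svi.step(data) for _ in range(num_steps)]\n",
    "\n",
    "# Increasingly flexible autoguides, all sharing one initialization strategy\n",
    "guides = {\n",
    "    \"AutoDelta\": AutoDelta(model, init_loc_fn=init_to_median),\n",
    "    \"AutoNormal\": AutoNormal(model, init_loc_fn=init_to_median),\n",
    "    \"AutoLowRankMultivariateNormal\": AutoLowRankMultivariateNormal(\n",
    "        model, init_loc_fn=init_to_median, rank=1),\n",
    "}\n",
    "all_losses = {name: fit(guide) for name, guide in guides.items()}\n",
    "stable = {name: all(math.isfinite(l) for l in losses)\n",
    "          for name, losses in all_losses.items()}\n",
    "```\n",
    "\n",
    "Loss values are not directly comparable across these guides (with `AutoDelta` the objective is a MAP objective rather than a proper ELBO), so the point of a run like this is to confirm that each stage optimizes stably before moving on."
   ]
  },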
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7. Parameter initialization matters: initialize guide distributions to have low variance\n",
    "\n",
    "Initialization in optimization problems can make all the difference between finding a good solution and failing catastrophically.\n",
    "It is difficult to come up with a comprehensive set of good practices for initialization, as good initialization schemes are often very problem dependent. \n",
    "In the context of Stochastic Variational Inference it is generally a good idea to initialize your guide distributions so that they have **low variance**. \n",
    "This is because the ELBO gradients you use to optimize the ELBO are stochastic. \n",
    "If the ELBO gradients you get at the beginning of ELBO optimization exhibit high variance, you may be led into numerically unstable or otherwise undesirable regions of parameter space. \n",
    "One way to guard against this potential hazard is to pay close attention to parameters in your guide that control variance. \n",
    "For example we would generally expect this to be a reasonably initialized guide:\n",
    "```python\n",
    "from pyro.distributions import constraints\n",
    "\n",
    "def good_guide():\n",
    "    scale = pyro.param(\"scale\", torch.tensor(0.05),\n",
    "                       constraint=constraints.positive)\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, scale))\n",
    "``` \n",
    "while the following high-variance guide is very likely to lead to problems:\n",
    "```python\n",
    "def bad_guide():\n",
    "    scale = pyro.param(\"scale\", torch.tensor(12345.6),\n",
    "                       constraint=constraints.positive)\n",
    "    pyro.sample(\"x\", dist.Normal(0.0, scale))\n",
    "``` \n",
    "\n",
    "Note that the initial variance of autoguides can be controlled with the `init_scale` argument, see e.g. [here](http://docs.pyro.ai/en/stable/infer.autoguide.html?highlight=init_scale#autonormal) for `AutoNormal`."
   ]
  },
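  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For autoguides, the corresponding knob is `init_scale`. The sketch below (the toy model and both scale values are assumptions for illustration) builds `AutoNormal` guides with a small and a large `init_scale` and estimates the spread of their samples before any optimization.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer.autoguide import AutoNormal\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "def model():\n",
    "    # hypothetical toy model with a single real-valued latent variable\n",
    "    pyro.sample(\"z\", dist.Normal(0.0, 1.0))\n",
    "\n",
    "def initial_guide_std(init_scale, num_samples=1000):\n",
    "    # Build a fresh AutoNormal guide and measure the spread of its initial samples\n",
    "    pyro.clear_param_store()\n",
    "    guide = AutoNormal(model, init_scale=init_scale)\n",
    "    # Calling an autoguide returns a dict that maps latent site names to samples\n",
    "    samples = torch.stack([guide()[\"z\"] for _ in range(num_samples)])\n",
    "    return samples.std().item()\n",
    "\n",
    "low_std = initial_guide_std(0.01)   # narrow initial guide\n",
    "high_std = initial_guide_std(10.0)  # very wide initial guide, likely to cause problems\n",
    "```\n",
    "\n",
    "For a real-valued latent like `z` the sample spread should roughly match `init_scale`; for constrained latents, `AutoNormal` applies the scale in the unconstrained space."
   ]
  },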
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8. Explore trade-offs controlled by `num_particles`, mini-batch size, etc.\n",
    "\n",
    "Optimization can be difficult if your ELBO exhibits large variance. \n",
    "One way you can try to mitigate this issue is to increase the number of particles used to compute each stochastic ELBO estimate:\n",
    "\n",
    "```python\n",
    "elbo = pyro.infer.Trace_ELBO(num_particles=10, \n",
    "                             vectorize_particles=True)\n",
    "```\n",
    "(Note that to use `vectorize_particles=True` you need to ensure your model and guide are properly vectorized; see the [tensor shapes tutorial](http://pyro.ai/examples/tensor_shapes.html) for best practices.)\n",
    "This results in lower variance gradients at the cost of more compute. \n",
    "If you are doing data subsampling, the mini-batch size offers a similar trade-off: larger mini-batch sizes reduce the variance at the cost of more compute. \n",
    "Although what's best is problem dependent, it's usually worth taking more gradient steps with fewer particles than fewer gradient steps with more particles. \n",
    "An important caveat applies when you're running on a GPU: for some models the cost of increasing `num_particles` or your mini-batch size may then be sublinear, in which case increasing `num_particles` is likely more attractive.\n"
   ]
  },
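  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Vectorized particles add a batch dimension to the left of all plate dimensions, so the model and guide must broadcast correctly. A minimal sketch, assuming a hypothetical toy model with a single `pyro.plate`; passing `max_plate_nesting=1` tells the ELBO how many plate dimensions are in use, so it knows where to put the particle dimension.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = 1.0 + torch.randn(100)  # hypothetical observations\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample(\"loc\", dist.Normal(0.0, 10.0))\n",
    "    # This plate occupies dim -1; the particle dimension is placed to its left\n",
    "    with pyro.plate(\"data\", len(data), dim=-1):\n",
    "        pyro.sample(\"obs\", dist.Normal(loc, 1.0), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    q_loc = pyro.param(\"q_loc\", torch.tensor(0.0))\n",
    "    q_scale = pyro.param(\"q_scale\", torch.tensor(0.1),\n",
    "                         constraint=constraints.positive)\n",
    "    pyro.sample(\"loc\", dist.Normal(q_loc, q_scale))\n",
    "\n",
    "# max_plate_nesting=1: the model and guide use at most one plate dimension\n",
    "elbo = Trace_ELBO(num_particles=10, vectorize_particles=True, max_plate_nesting=1)\n",
    "pyro.clear_param_store()\n",
    "svi = SVI(model, guide, Adam({\"lr\": 0.01}), loss=elbo)\n",
    "losses = [svi.step(data) for _ in range(100)]\n",
    "all_finite = all(math.isfinite(loss) for loss in losses)\n",
    "```\n",
    "\n",
    "If a model or guide is not written with batch dimensions in mind, vectorized particles will typically fail with shape errors; in that case `vectorize_particles=False` computes the particles sequentially instead, at a higher cost."
   ]
  },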
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 9. Use `TraceMeanField_ELBO` if applicable\n",
    "\n",
    "The basic `ELBO` implementation in Pyro, [Trace_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html?highlight=tracemeanfield#pyro.infer.trace_elbo.Trace_ELBO), uses stochastic samples to estimate the KL divergence term. \n",
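    "When analytic KL divergences are available, you may be able to lower ELBO variance by using analytic KL divergences instead. This functionality is provided by [TraceMeanField_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html?highlight=tracemeanfield#pyro.infer.trace_mean_field_elbo.TraceMeanField_ELBO). Note that `TraceMeanField_ELBO` places restrictions on the model and guide (in particular, the guide should have a mean-field structure); see its documentation for details."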
   ]
  },
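  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Switching estimators is a one-line change to the `loss` argument. A minimal sketch, assuming a hypothetical model in which the latent `z` has a Normal prior and a Normal guide, so that their KL divergence has a closed form:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, TraceMeanField_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = 0.5 + torch.randn(20)  # hypothetical observations\n",
    "\n",
    "def model(data):\n",
    "    # Normal prior on z: its KL divergence from a Normal guide has a closed form\n",
    "    z = pyro.sample(\"z\", dist.Normal(0.0, 1.0))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        pyro.sample(\"obs\", dist.Normal(z, 1.0), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    q_loc = pyro.param(\"q_loc\", torch.tensor(0.0))\n",
    "    q_scale = pyro.param(\"q_scale\", torch.tensor(0.1),\n",
    "                         constraint=constraints.positive)\n",
    "    pyro.sample(\"z\", dist.Normal(q_loc, q_scale))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "# Compared with a Trace_ELBO setup, only the loss argument changes\n",
    "svi = SVI(model, guide, Adam({\"lr\": 0.01}), loss=TraceMeanField_ELBO())\n",
    "losses = [svi.step(data) for _ in range(200)]\n",
    "all_finite = all(math.isfinite(loss) for loss in losses)\n",
    "```\n",
    "\n",
    "Whether this actually lowers variance depends on the model and on how close the guide is to the posterior, so it is worth comparing loss traces under both objectives."
   ]
  },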
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10. Consider normalizing your ELBO\n",
    "\n",
    "By default Pyro computes an unnormalized ELBO, i.e. it computes the quantity that is a lower bound to the log evidence computed on the full set of data that is being conditioned on. \n",
    "For large datasets this can be a number of large magnitude. \n",
    "Since computers use finite precision (e.g. 32-bit floats) to do arithmetic, large numbers can be problematic for numerical stability, since they can lead to loss of precision, under/overflow, etc.\n",
    "For this reason it can be helpful in many cases to normalize your ELBO so that it is roughly order one. \n",
    "This can also be helpful for getting a rough feeling for how good your ELBO numbers are. \n",
    "For example if we have $N$ datapoints of dimension $D$ (e.g. $N$ real-valued vectors of dimension $D$) then we generally expect a reasonably well optimized ELBO to be order $N \\times D$. \n",
    "Thus if we renormalize our ELBO by a factor of $N \\times D$ we expect an ELBO of order one. \n",
    "While this is just a rough rule-of-thumb, if we use this kind of normalization and obtain ELBO values like $-123.4$ or $1234.5$ then something is probably wrong: perhaps our model is terribly mis-specified; perhaps our initialization is catastrophically bad, etc. \n",
    "For details on how you can scale your ELBO by a normalization constant see [this tutorial](http://pyro.ai/examples/custom_objectives.html#Example:-Scaling-the-Loss)."
   ]
  },
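  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to apply the $N \\times D$ normalization, sketched here for a hypothetical Gaussian model of an $N \\times D$ data matrix, is to wrap both the model and the guide in `pyro.poutine.scale`, which multiplies every log-probability term by a constant. Scaling by a positive constant changes the magnitude of the objective but not the location of its optimum.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro import poutine\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "N, D = 1000, 5\n",
    "data = torch.randn(N, D)  # hypothetical N x D dataset\n",
    "\n",
    "def model(data):\n",
    "    loc = pyro.sample(\"loc\", dist.Normal(torch.zeros(D), 10.0).to_event(1))\n",
    "    with pyro.plate(\"data\", N):\n",
    "        pyro.sample(\"obs\", dist.Normal(loc, 1.0).to_event(1), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    q_loc = pyro.param(\"q_loc\", torch.zeros(D))\n",
    "    q_scale = pyro.param(\"q_scale\", 0.1 * torch.ones(D),\n",
    "                         constraint=constraints.positive)\n",
    "    pyro.sample(\"loc\", dist.Normal(q_loc, q_scale).to_event(1))\n",
    "\n",
    "# Divide every log-probability term in model and guide by N * D\n",
    "scale = 1.0 / (N * D)\n",
    "scaled_model = poutine.scale(model, scale=scale)\n",
    "scaled_guide = poutine.scale(guide, scale=scale)\n",
    "\n",
    "pyro.clear_param_store()\n",
    "svi = SVI(scaled_model, scaled_guide, Adam({\"lr\": 0.01}), loss=Trace_ELBO())\n",
    "losses = [svi.step(data) for _ in range(200)]\n",
    "all_finite = all(math.isfinite(loss) for loss in losses)\n",
    "```\n",
    "\n",
    "For this well-specified toy model the normalized loss should settle at a value of order one (close to the average negative log-likelihood of a standard normal observation, about 1.4), in line with the rule of thumb above."
   ]
  },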
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 11. Pay attention to scales\n",
    "\n",
    "Scales of numbers matter. \n",
    "They matter for at least two important reasons: \n",
    "i) scales can make or break a particular initialization scheme; \n",
    "ii) as discussed in the previous section, scales can have an impact on numerical precision and stability.\n",
    "\n",
    "To make this concrete suppose you are doing linear regression, i.e.\n",
    "you're learning a linear map of the form $Y = WX$. Often the data comes with particular units. \n",
    "For example some of the components of the covariate $X$ may be in units of dollars (e.g. house prices), while others may be in units of density (e.g. residents per square mile). \n",
    "Perhaps the first covariate has typical values like $10^5$, while the second covariate has typical values like $10^2$. \n",
    "You should always pay attention when you encounter numbers that range across many orders of magnitude. \n",
    "In many cases it makes sense to normalize things so that they are order unity. \n",
    "For example you might measure house prices in units of \\$100,000.\n",
    "\n",
    "These sorts of data transformations can have a number of benefits for downstream modeling and inference. \n",
    "For example if you've normalized all of your covariates appropriately, it may be reasonable to set a simple \n",
    "isotropic prior on your weights\n",
    "\n",
    "```python\n",
    "pyro.sample(\"W\", dist.Normal(torch.zeros(2), torch.ones(2)))\n",
    "```\n",
    "instead of having to specify different prior covariances for different covariates\n",
    "```python\n",
    "prior_scale = torch.tensor([1.0e-5, 1.0e-2])\n",
    "pyro.sample(\"W\", dist.Normal(torch.zeros(2), prior_scale))\n",
    "```\n",
    "There are other benefits too. \n",
    "It now becomes easier to initialize appropriate parameters for your guide. \n",
    "It is also now much more likely that the default initializations used by a [pyro.infer.AutoGuide](http://docs.pyro.ai/en/stable/infer.autoguide.html#module-pyro.infer.autoguide) will work for your problem."
   ]
  },
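  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of such a rescaling, using hypothetical covariates with the magnitudes from the example above (prices around $10^5$ dollars, densities around $10^2$ residents per square mile). Standardizing each column to zero mean and unit standard deviation is one common choice (an assumption here; dividing by a fixed unit such as \\$100,000 works too), and the statistics are kept so that new inputs can be transformed the same way.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_houses = 500\n",
    "# Hypothetical raw covariates: price in dollars, density in residents per square mile\n",
    "price = 1e5 * (1.0 + 0.3 * torch.randn(num_houses))\n",
    "density = 1e2 * (1.0 + 0.3 * torch.randn(num_houses))\n",
    "X_raw = torch.stack([price, density], dim=-1)  # shape (num_houses, 2)\n",
    "\n",
    "# Per-column statistics, computed once on the training data and reused for new inputs\n",
    "X_mean = X_raw.mean(dim=0)\n",
    "X_std = X_raw.std(dim=0)\n",
    "\n",
    "def standardize(X):\n",
    "    # Shift and rescale each covariate to zero mean and unit standard deviation\n",
    "    return (X - X_mean) / X_std\n",
    "\n",
    "X = standardize(X_raw)\n",
    "```\n",
    "\n",
    "After a transformation like this, the simple isotropic prior `dist.Normal(torch.zeros(2), torch.ones(2))` on the weights becomes a reasonable starting point."
   ]
  },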
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 12. Keep validation enabled\n",
    "\n",
    "By default Pyro enables validation logic that can be helpful in debugging models and guides. \n",
    "For example, validation logic will inform you when distribution parameters become invalid.\n",
    "Unless you have good reason to do otherwise, keep the validation logic enabled. \n",
    "Once you're satisfied with a model and inference procedure, you may wish to disable validation by calling `pyro.enable_validation(False)` (see [pyro.enable_validation](http://docs.pyro.ai/en/stable/primitives.html?highlight=enable_validation#pyro.primitives.enable_validation)).\n",
    "\n",
    "Similarly in the context of `ELBOs` it is a good idea to set \n",
    "```python\n",
    "strict_enumeration_warning=True\n",
    "```\n",
    "when you are enumerating discrete latent variables."
   ]
  },
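  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of both settings. `strict_enumeration_warning` is an argument accepted by Pyro's ELBO classes; it is shown here on `TraceEnum_ELBO`, the ELBO used with enumeration, and the `max_plate_nesting` value is an assumption that depends on your model.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.infer import TraceEnum_ELBO\n",
    "\n",
    "# Keep validation on while developing (pass False once you trust the model)\n",
    "pyro.enable_validation(True)\n",
    "\n",
    "# An ELBO for enumerated discrete latents, with the strict warning spelled out\n",
    "elbo = TraceEnum_ELBO(max_plate_nesting=1, strict_enumeration_warning=True)\n",
    "\n",
    "# With validation enabled, an invalid distribution parameter is reported immediately\n",
    "try:\n",
    "    dist.Normal(0.0, torch.tensor(-1.0))\n",
    "    caught_invalid_scale = False\n",
    "except ValueError:\n",
    "    caught_invalid_scale = True\n",
    "```"
   ]
  },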
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 13. Tensor shape errors\n",
    "\n",
    "If you're running into tensor shape errors please make sure you have carefully read the [corresponding tutorial](http://pyro.ai/examples/tensor_shapes.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 14. Enumerate discrete latent variables if possible\n",
    "\n",
    "If your model contains discrete latent variables it may make sense to enumerate them out exactly, since this can significantly reduce ELBO variance. \n",
    "For more discussion see the [corresponding tutorial](http://pyro.ai/examples/enumeration.html)."
   ]
  },
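  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch, assuming a hypothetical two-component Gaussian mixture with known component means and learnable mixture weights: decorating the model with `config_enumerate` and using `TraceEnum_ELBO` sums out the discrete assignment `z` exactly instead of sampling it. Since nothing else is latent, the guide is empty.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Hypothetical data from two well-separated clusters (60 near -3, 40 near +3)\n",
    "data = torch.cat([torch.randn(60) - 3.0, torch.randn(40) + 3.0])\n",
    "component_locs = torch.tensor([-3.0, 3.0])  # assumed known, for simplicity\n",
    "\n",
    "@config_enumerate\n",
    "def model(data):\n",
    "    weights = pyro.param(\"weights\", torch.ones(2) / 2,\n",
    "                         constraint=constraints.simplex)\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        # z is discrete; TraceEnum_ELBO sums over both of its values instead of sampling\n",
    "        z = pyro.sample(\"z\", dist.Categorical(weights))\n",
    "        pyro.sample(\"obs\", dist.Normal(component_locs[z], 1.0), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    # z is enumerated in the model and nothing else is latent, so the guide is empty\n",
    "    pass\n",
    "\n",
    "pyro.clear_param_store()\n",
    "svi = SVI(model, guide, Adam({\"lr\": 0.05}), loss=TraceEnum_ELBO(max_plate_nesting=1))\n",
    "losses = [svi.step(data) for _ in range(200)]\n",
    "all_finite = all(math.isfinite(l) for l in losses)\n",
    "weights = pyro.param(\"weights\").detach()\n",
    "```\n",
    "\n",
    "Because every latent variable in this toy model is summed out exactly, the loss here is a deterministic function of the parameters with no sampling noise at all, and `weights` should end up close to the cluster proportions (roughly 0.6 and 0.4)."
   ]
  },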
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 15. Some complex models can benefit from KL annealing\n",
    "\n",
    "The particular form of the ELBO encodes a trade-off between model fit via the expected log likelihood term and a prior regularization term via the KL divergence. \n",
    "In some cases the KL divergence can act as a barrier that makes it difficult to find good optima. \n",
    "In these cases it can help to anneal the relative strength of the KL divergence term during optimization. For further discussion see the [deep markov model tutorial](http://pyro.ai/examples/dmm.html#The-Black-Magic-of-Optimization)."
   ]
  },
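  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of one way to do this, under stated assumptions (a hypothetical toy model, and a linear schedule from 0.1 to 1.0 over the first 100 steps): wrap the latent sample site in `pyro.poutine.scale` in both the model and the guide, so the KL-related terms are multiplied by an annealing factor while the likelihood keeps full weight. Once the factor reaches 1.0 the objective is the ordinary ELBO again.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro import poutine\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = 2.0 + torch.randn(50)  # hypothetical observations\n",
    "\n",
    "def model(data, annealing_factor=1.0):\n",
    "    # Down-weight log p(z) by the annealing factor\n",
    "    with poutine.scale(scale=annealing_factor):\n",
    "        z = pyro.sample(\"z\", dist.Normal(0.0, 1.0))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        # The likelihood term keeps its full weight\n",
    "        pyro.sample(\"obs\", dist.Normal(z, 1.0), obs=data)\n",
    "\n",
    "def guide(data, annealing_factor=1.0):\n",
    "    q_loc = pyro.param(\"q_loc\", torch.tensor(0.0))\n",
    "    q_scale = pyro.param(\"q_scale\", torch.tensor(0.1),\n",
    "                         constraint=constraints.positive)\n",
    "    # Down-weight log q(z) by the same factor, so the whole KL estimate is scaled\n",
    "    with poutine.scale(scale=annealing_factor):\n",
    "        pyro.sample(\"z\", dist.Normal(q_loc, q_scale))\n",
    "\n",
    "def annealing_schedule(step, warmup_steps=100, min_factor=0.1):\n",
    "    # Linear ramp from min_factor up to 1.0, then constant at 1.0\n",
    "    return min(1.0, min_factor + (1.0 - min_factor) * step / warmup_steps)\n",
    "\n",
    "pyro.clear_param_store()\n",
    "svi = SVI(model, guide, Adam({\"lr\": 0.01}), loss=Trace_ELBO())\n",
    "factors = []\n",
    "for step in range(300):\n",
    "    factor = annealing_schedule(step)\n",
    "    factors.append(factor)\n",
    "    svi.step(data, annealing_factor=factor)  # keyword args reach both model and guide\n",
    "```"
   ]
  },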
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 16. Consider clipping gradients or constraining parameters defensively\n",
    "\n",
    "Certain parameters in your model or guide may control distribution parameters that can be sensitive to numerical issues. \n",
    "For example, the `concentration` and `rate` parameters that define a [Gamma](http://docs.pyro.ai/en/stable/distributions.html#gamma) distribution may exhibit such sensitivity. \n",
    "In these cases it may make sense to clip gradients or constrain parameters defensively. \n",
    "See [this code snippet](https://github.com/pyro-ppl/pyro/blob/dev/examples/sparse_gamma_def.py#L135) for an example of gradient clipping. \n",
    "For a simple example of \"defensive\" parameter constraints consider the `concentration` parameter of a `Gamma` distribution. \n",
    "This parameter must be positive: `concentration` > 0.\n",
    "If we want to ensure that `concentration` stays away from zero we can use a `param` statement with an appropriate constraint:\n",
    "\n",
    "```python\n",
    "from pyro.distributions import constraints\n",
    "\n",
    "concentration = pyro.param(\"concentration\", torch.tensor(0.5),\n",
    "                           constraints.greater_than(0.001))\n",
    "```\n",
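    "\n",
    "For the gradient-clipping side, one option (used here for illustration; it is not the only approach) is the `clip_norm` argument of `ClippedAdam`, which limits the magnitude of the gradients before each Adam update. The sketch below combines it with the defensive `greater_than` constraint shown above, in a hypothetical Gamma-Poisson model whose guide is a `Gamma` distribution; the data, clipping threshold, and learning rate are assumptions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import pyro\n",
    "import pyro.distributions as dist\n",
    "from pyro.distributions import constraints\n",
    "from pyro.infer import SVI, Trace_ELBO\n",
    "from pyro.optim import ClippedAdam\n",
    "\n",
    "torch.manual_seed(0)\n",
    "data = torch.poisson(3.0 * torch.ones(50))  # hypothetical count data\n",
    "\n",
    "def model(data):\n",
    "    rate = pyro.sample(\"rate\", dist.Gamma(2.0, 1.0))\n",
    "    with pyro.plate(\"data\", len(data)):\n",
    "        pyro.sample(\"obs\", dist.Poisson(rate), obs=data)\n",
    "\n",
    "def guide(data):\n",
    "    # Defensive constraints keep both Gamma parameters away from zero\n",
    "    concentration = pyro.param(\"concentration\", torch.tensor(0.5),\n",
    "                               constraint=constraints.greater_than(0.001))\n",
    "    q_rate = pyro.param(\"q_rate\", torch.tensor(0.5),\n",
    "                        constraint=constraints.greater_than(0.001))\n",
    "    pyro.sample(\"rate\", dist.Gamma(concentration, q_rate))\n",
    "\n",
    "pyro.clear_param_store()\n",
    "# clip_norm limits the size of the gradients used in each Adam update\n",
    "optimizer = ClippedAdam({\"lr\": 0.01, \"clip_norm\": 10.0})\n",
    "svi = SVI(model, guide, optimizer, loss=Trace_ELBO())\n",
    "losses = [svi.step(data) for _ in range(300)]\n",
    "all_finite = all(math.isfinite(loss) for loss in losses)\n",
    "```\n",
    "\n",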
    "These kinds of tricks can help ensure that your models and guides stay away from numerically dangerous parts of parameter space."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
